{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d34f8c48-90fc-4981-8d2b-b47724c2a6dd",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "# Client Examples\n",
    "\n",
    "`Client` provides a uniform interface for interacting with LLMs from various providers. It adapts the official Python libraries from providers such as Mistral, OpenAI, Groq, Anthropic, and AWS so that they conform to the OpenAI chat completion interface. In some cases it calls the provider's REST endpoints directly.\n",
    "\n",
    "Below are some examples of how to use `Client` to interact with different LLMs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-04T15:30:02.064319Z",
     "start_time": "2024-07-04T15:30:02.051986Z"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "\n",
    "sys.path.append('../../aisuite')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f75736ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "\n",
    "def configure_environment(additional_env_vars=None):\n",
    "    \"\"\"\n",
    "    Load environment variables from a .env file, then overlay caller-supplied variables.\n",
    "\n",
    "    :param additional_env_vars: Optional dictionary mapping variable names to values.\n",
    "        Each entry is written to os.environ after the .env file has been loaded.\n",
    "    \"\"\"\n",
    "    # Load from the .env file located by find_dotenv(), if available\n",
    "    load_dotenv(find_dotenv())\n",
    "\n",
    "    # Nothing to overlay when no (or an empty) mapping is given\n",
    "    if not additional_env_vars:\n",
    "        return\n",
    "\n",
    "    for name, setting in additional_env_vars.items():\n",
    "        os.environ[name] = setting\n",
    "\n",
    "\n",
    "# Entries still equal to this marker count as \"not filled in\" and are skipped below, so an\n",
    "# unedited template value cannot overwrite a real key that came from .env or the shell.\n",
    "PLACEHOLDER_VALUE = 'xxx'\n",
    "\n",
    "# Optional overrides for API keys and credentials. Prefer keeping real keys in a .env file;\n",
    "# anything left as the placeholder is ignored.\n",
    "additional_keys = {\n",
    "    'GROQ_API_KEY': PLACEHOLDER_VALUE,\n",
    "    'AWS_ACCESS_KEY_ID': PLACEHOLDER_VALUE,\n",
    "    'AWS_SECRET_ACCESS_KEY': PLACEHOLDER_VALUE,\n",
    "    'ANTHROPIC_API_KEY': PLACEHOLDER_VALUE,\n",
    "}\n",
    "\n",
    "# Configure environment, applying only the entries that were actually filled in\n",
    "configure_environment(\n",
    "    additional_env_vars={\n",
    "        name: setting\n",
    "        for name, setting in additional_keys.items()\n",
    "        if setting != PLACEHOLDER_VALUE\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4de3a24f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-04T15:31:12.914321Z",
     "start_time": "2024-07-04T15:31:12.796445Z"
    }
   },
   "outputs": [],
   "source": [
    "import aisuite as ai\n",
    "\n",
    "# A single client instance is reused by every provider example in this notebook.\n",
    "client = ai.Client()\n",
    "\n",
    "# Chat history sent with each request: a system instruction followed by the user's prompt.\n",
    "messages = [\n",
    "    dict(role=\"system\", content=\"Respond in Pirate English. Always try to include the phrase - No rum No fun.\"),\n",
    "    dict(role=\"user\", content=\"Tell me a joke about Captain Jack Sparrow\"),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "520a6879",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The model string has the form \"anthropic:<model_name>\"; the variable is named after the\n",
    "# model it actually holds (Claude 3.5 Sonnet).\n",
    "anthropic_claude_3_5_sonnet = \"anthropic:claude-3-5-sonnet-20240620\"\n",
    "response = client.chat.completions.create(\n",
    "    model=anthropic_claude_3_5_sonnet,\n",
    "    messages=messages,\n",
    ")\n",
    "print(response.choices[0].message.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9893c7e4-799a-42c9-84de-f9e643044462",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The model string has the form \"aws:<model_id>\".\n",
    "aws_bedrock_llama3_8b = \"aws:meta.llama3-1-8b-instruct-v1:0\"\n",
    "response = client.chat.completions.create(\n",
    "    model=aws_bedrock_llama3_8b,\n",
    "    messages=messages,\n",
    ")\n",
    "print(response.choices[0].message.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e46c20a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# IMPORTANT: Azure expects the model endpoint to be passed in the format \"azure:<model_name>\".\n",
    "# The model name is the deployment name shown under Project/Deployments.\n",
    "# In the example below, the model is \"mistral-large-2407\", but the name given to the\n",
    "# deployment is \"aisuite-mistral-large-2407\" under the deployments section in Azure.\n",
    "azure_settings = {\n",
    "    \"api_key\": os.environ[\"AZURE_API_KEY\"],\n",
    "    \"base_url\": \"https://aisuite-mistral-large-2407.westus3.models.ai.azure.com/v1/\",\n",
    "}\n",
    "client.configure({\"azure\": azure_settings})\n",
    "\n",
    "azure_model = \"azure:aisuite-mistral-large-2407\"\n",
    "response = client.chat.completions.create(\n",
    "    model=azure_model,\n",
    "    messages=messages,\n",
    ")\n",
    "print(response.choices[0].message.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f996b121",
   "metadata": {},
   "outputs": [],
   "source": [
    "# HuggingFace expects the model to be passed in the format \"huggingface:<model_name>\",\n",
    "# where the model name is the full name of the model on HuggingFace.\n",
    "# In the example below, the model is \"mistralai/Mistral-7B-Instruct-v0.3\",\n",
    "# deployed as a serverless inference endpoint on HuggingFace.\n",
    "hf_model = \"huggingface:mistralai/Mistral-7B-Instruct-v0.3\"\n",
    "response = client.chat.completions.create(\n",
    "    model=hf_model,\n",
    "    messages=messages,\n",
    ")\n",
    "print(response.choices[0].message.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9b2aad6-8603-4227-9566-778f714eb0b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Groq expects the model to be passed in the format \"groq:<model_name>\",\n",
    "# where the model name is the full name of the model in Groq.\n",
    "# In the example below, the model is \"llama3-8b-8192\".\n",
    "# (Alternative model string to try instead: \"groq:llama3-70b-8192\".)\n",
    "groq_llama3_8b = \"groq:llama3-8b-8192\"\n",
    "response = client.chat.completions.create(\n",
    "    model=groq_llama3_8b,\n",
    "    messages=messages,\n",
    ")\n",
    "print(response.choices[0].message.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6819ac17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two Ollama model strings are defined; only the phi3:mini one is used in the request below.\n",
    "ollama_tinyllama = \"ollama:tinyllama\"\n",
    "ollama_phi3mini = \"ollama:phi3:mini\"\n",
    "response = client.chat.completions.create(\n",
    "    model=ollama_phi3mini,\n",
    "    messages=messages,\n",
    "    temperature=0.75,\n",
    ")\n",
    "print(response.choices[0].message.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a94961b2bddedbb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-04T15:31:39.472675Z",
     "start_time": "2024-07-04T15:31:38.283368Z"
    }
   },
   "outputs": [],
   "source": [
    "# The model string has the form \"mistral:<model_name>\"; this request uses temperature 0.2.\n",
    "mistral_7b = \"mistral:open-mistral-7b\"\n",
    "response = client.chat.completions.create(\n",
    "    model=mistral_7b,\n",
    "    messages=messages,\n",
    "    temperature=0.2,\n",
    ")\n",
    "print(response.choices[0].message.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "611210a4dc92845f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The model string has the form \"openai:<model_name>\".\n",
    "openai_gpt35 = \"openai:gpt-3.5-turbo\"\n",
    "response = client.chat.completions.create(\n",
    "    model=openai_gpt35,\n",
    "    messages=messages,\n",
    "    temperature=0.75,\n",
    ")\n",
    "print(response.choices[0].message.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "321783ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The model string has the form \"fireworks:<model_path>\"; extra keyword arguments\n",
    "# (here the presence and frequency penalties) are passed alongside the temperature.\n",
    "fireworks_model = \"fireworks:accounts/fireworks/models/llama-v3p2-3b-instruct\"\n",
    "response = client.chat.completions.create(\n",
    "    model=fireworks_model,\n",
    "    messages=messages,\n",
    "    temperature=0.75,\n",
    "    presence_penalty=0.5,\n",
    "    frequency_penalty=0.5,\n",
    ")\n",
    "print(response.choices[0].message.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e30e5ae0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The model string has the form \"together:<model_name>\"; this request also sets\n",
    "# top_p and top_k in addition to the temperature.\n",
    "togetherai_model = \"together:meta-llama/Llama-3.2-3B-Instruct-Turbo\"\n",
    "response = client.chat.completions.create(\n",
    "    model=togetherai_model,\n",
    "    messages=messages,\n",
    "    temperature=0.75,\n",
    "    top_p=0.7,\n",
    "    top_k=50,\n",
    ")\n",
    "print(response.choices[0].message.content)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
